import os
from transformer import make_model
from training import Batch, greedy_decode
import os
import yaml
import torch.nn.functional as F
import torch
from real_world_example import load_vocab, load_tokenizers, tokenize

# Module-level setup: tokenizers, vocabularies, and the checkpoint path are
# loaded once at import time and shared by the functions below.
spacy_de, spacy_en = load_tokenizers()
vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)
# Use a context manager so the config file handle is closed promptly
# (the bare `open(...)` previously leaked the handle).
with open('config.yaml', 'r') as _config_file:
    config = yaml.safe_load(_config_file)
checkpoint_path = os.path.join(config['checkpoint_prefix'], 'transformer_model.pt')


def test_with_one_example(model, example_str, pad_idx=2, eos_string='</s>'):
    """Greedy-decode one pre-processed source tensor and print the result.

    Prints the detokenized source text followed by the model's translation,
    truncated at the first end-of-sentence token.

    Args:
        model: trained transformer (should already be in eval mode).
        example_str: (1, max_padding) tensor of source token ids — the output
            of ``process_str``; despite the name, not a raw string.
        pad_idx: id of the padding token; padded positions are omitted from
            both printed texts.
        eos_string: token text that terminates the printed model output.
    """
    rb = Batch(example_str, pad=pad_idx)

    # Detokenize the source, skipping padding. (Previously this compared
    # against a hard-coded 2 instead of pad_idx, which broke any caller
    # passing a non-default pad id.)
    src_tokens = [
        vocab_src.get_itos()[x] for x in rb.src[0] if x != pad_idx
    ]
    print(
        "Source Text(Input): " +
        " ".join(src_tokens).replace("\n", "")
    )

    # start_symbol=0 matches the <s> id that process_str prepends.
    model_out = greedy_decode(model, rb.src, rb.src_mask, max_len=72, start_symbol=0)[0]
    # Keep everything up to (and including) the first EOS marker.
    model_txt = (
            " ".join([vocab_tgt.get_itos()[x]
                      for x in model_out if x != pad_idx])
            .split(eos_string, 1)[0] + eos_string
    )
    print("Model Output: " + model_txt.replace("\n", ""))


def process_str(src_text: str, max_padding=128, pad_id=2):
    """Tokenize a German sentence into a fixed-length id tensor.

    Args:
        src_text: raw source sentence.
        max_padding: fixed length of the returned sequence.
        pad_id: id of the padding token used to right-pad the sequence.

    Returns:
        A (1, max_padding) int64 CPU tensor laid out as
        ``<s> + token ids + </s>`` followed by ``pad_id`` padding.
    """
    device = torch.device("cpu")
    bs_id = torch.tensor([0], device=device)  # <s> token id (sentence start)
    eos_id = torch.tensor([1], device=device)  # </s> token id (sentence end)

    token_ids = vocab_src(tokenize(src_text, spacy_de))
    # Truncate overly long inputs so <s>/</s> always fit within max_padding.
    # Previously a too-long sentence produced a negative pad in F.pad below,
    # silently chopping off the tail of the sequence including </s>.
    token_ids = token_ids[:max_padding - 2]

    processed_src = torch.cat(
        [
            bs_id,
            torch.tensor(
                token_ids,
                dtype=torch.int64,
                device=device),
            eos_id
        ],
        dim=0)
    # Right-pad to the fixed length expected by the model.
    processed_src = F.pad(
        processed_src,
        pad=(0, max_padding - len(processed_src)),
        value=pad_id
    )

    return processed_src.unsqueeze(0)


if __name__ == "__main__":
    model = make_model(len(vocab_src), len(vocab_tgt), n=6)
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    model.eval()
    example_str = "Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen"
    processed_src = process_str(example_str)
    test_with_one_example(model, processed_src)
